{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "b9c4312a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "id": "a00a8724",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = np.loadtxt(\"../data/movielens_100k.csv\",delimiter=',',dtype=int)\n",
    "data[:,:2] = data[:,:2] - 1\n",
    "np.random.seed(0)\n",
    "np.random.shuffle(data)\n",
    "\n",
    "ratio = 0.8\n",
    "split = int(len(data)*ratio)\n",
    "train, test = data[:split], data[split:]\n",
    "user_train, item_train, y_train = train[:,0], train[:,1], train[:,2]\n",
    "user_test, item_test, y_test = test[:,0], test[:,1], test[:,2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "id": "f40dd4e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "users = set()\n",
    "items = set()\n",
    "for user, item, y in data:\n",
    "    users.add(user)\n",
    "    items.add(item)\n",
    "user_num = len(users)\n",
    "item_num = len(items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "id": "8b294dfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "d = 16\n",
    "class MF():\n",
    "    def __init__(self, N, M, d):\n",
    "        self.user_params = np.ones((N,d))\n",
    "        self.item_params = np.ones((M,d))\n",
    "    def pred(self, user_id, item_id):\n",
    "        user_param = self.user_params[user_id]\n",
    "        item_param = self.item_params[item_id]\n",
    "        return user_param @ item_param.T\n",
    "    def update(self, P_grad, Q_grad, lr):\n",
    "        self.user_params -= lr * P_grad\n",
    "        self.item_params -= lr * Q_grad\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "id": "157399cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_generator(x, y, z, batch_size, shuffle=True):\n",
    "    batch_count = 0\n",
    "    if shuffle:\n",
    "        idx = np.random.permutation(len(x))\n",
    "        x = x[idx]\n",
    "        y = y[idx]\n",
    "        z = z[idx]\n",
    "    while True:\n",
    "        st = batch_count * batch_size\n",
    "        ed = min(st + batch_size, len(x))\n",
    "        if st>=ed:\n",
    "            break\n",
    "        batch_count+=1\n",
    "        yield x[st:ed], y[st:ed], z[st:ed]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "id": "f7da6595",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "lbd = 1e-4\n",
    "lr = 0.01\n",
    "max_training = 35\n",
    "model = MF(N = user_num, M = item_num,d = d)\n",
    "gen = batch_generator(user_train, item_train, y_train, batch_size = batch_size, shuffle=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "id": "b80ec5f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 35/35 [00:45<00:00,  1.31s/it]\n"
     ]
    }
   ],
   "source": [
    "train_losses = []\n",
    "test_losses = []\n",
    "for epoch in tqdm(range(max_training)):\n",
    "    gen = batch_generator(user_train, item_train, y_train, batch_size = batch_size, shuffle=True)\n",
    "    P = model.user_params\n",
    "    Q = model.item_params\n",
    "    P_grad = np.zeros_like(P)\n",
    "    Q_grad = np.zeros_like(Q)\n",
    "    for user_batch, item_batch, y_bacth in gen:\n",
    "        for user, item, y in zip(user_batch, item_batch, y_bacth):\n",
    "            P_grad[user] = P_grad[user] + (model.pred(user,item) - y) * Q[item] + lbd * P[user]\n",
    "            Q_grad[item] = Q_grad[item] + (model.pred(user,item) - y) * P[user] + lbd * Q[item]\n",
    "    model.update(P_grad/len(user_batch), Q_grad/len(item_batch), lr)\n",
    "    \n",
    "    train_loss = 0\n",
    "    for user,item,y in zip(user_train, item_train, y_train):\n",
    "        train_loss += (model.pred(user, item) - y)**2\n",
    "    train_loss = np.sqrt((train_loss / len(user_train)))\n",
    "    train_losses.append(train_loss)\n",
    "\n",
    "    test_loss = 0\n",
    "    for user,item,y in zip(user_test, item_test, y_test):\n",
    "        test_loss += (model.pred(user, item) - y)**2\n",
    "    test_loss = np.sqrt((test_loss / len(user_test)))\n",
    "    test_losses.append(test_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "9d2a7704",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f101c70da90>]"
      ]
     },
     "execution_count": 155,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhYAAAGdCAYAAABO2DpVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8ekN5oAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA4WklEQVR4nO3deXiU5b3/8c9kh5AFwhoIIewCgoKKqBAr4FrFpa0i59Ttp3XXU1dsrcvpKag9XeyxaovV2qq4VKTVuqEmiIKyyiZ7IIGwC0kIJIHk+f3xZZiZkJCEzMyTzLxf1/VcM/PMJPlmnMt8uJ/vfd8ex3EcAQAABEGM2wUAAIDIQbAAAABBQ7AAAABBQ7AAAABBQ7AAAABBQ7AAAABBQ7AAAABBQ7AAAABBExfOH1ZTU6Pi4mKlpKTI4/GE80cDAIDj5DiOysrKlJmZqZiYY49JhDVYFBcXKysrK5w/EgAABElRUZF69OhxzNeENVikpKRIssJSU1PD+aMBAMBxKi0tVVZW1pG/48cS1mDhvfyRmppKsAAAoJVpTBsDzZsAACBoCBYAACBoCBYAACBoCBYAACBoCBYAACBoCBYAACBoCBYAACBoCBYAACBoCBYAACBoCBYAACBoCBYAACBoCBYAACBoIiNYlJZKv/iFdMMNkuO4XQ0AAFHL4zjh+0tcWlqqtLQ0lZSUBHd30wMHpLZt7f6uXVJGRvC+NwAAUa4pf78jY8SiTRupa1e7X1Dgbi0AAESxyAgWktS7t90SLAAAcE1EBIudO6UZ3+RIkmrWEywAAHBLRASLDh2kbw9YsChfvsHlagAAiF4RESxiY6XSDAsWB9cwYgEAgFsiIlhI0sEs67GILSRYAADglogJFvH9bcQieedGqbra3WIAAIhSERMs0of00EHFKa7moFRc7HY5AABEpYgJFr36xGqTsu0BU04BAHBFxASLnBypQHY5RBuYGQIAgBsiMlgcWsuIBQAAboiYYNGpk7Q53maG7F9BsAAAwA0REyw8Hml/FxuxqF7LpRAAANwQMcFCkmqyLVgkbGHEAgAAN0RUsEg8wS6FJJcUSxUVLlcDAED0iahg0fmEDJWpnT3YuNHVWgAAiEYRFSxyent8U05ZywIAgLCLrGCRI22QXQ4hWAAAEH4RFyy8IxZVq5gZAgBAuEVUsEhNlXa0tWBxYCUjFgAAhFtEBQtJOtDNLoU4XAoBACDsIi5YeHrbiEWbYi6FAAAQbhEXLNoO6iVJSqwokfbscbcYAACiTMQFi+79k7VNXewBl0MAAAiriAsWbJ8OAIB7IjpYOBsYsQAAIJwiLlj06uULFhWrCBYAAIRTxAWLpCTpuzSbclpJsAAAIKwiLlhIUlUPG7GI2UiPBQAA4RSRwSKurwWLtjs2SjU17hYDAEAUichgkTo4S4cUq7jqKmnrVrfLAQAgakRksMjuE6dC9bQHTDkFACBsIjJYBKxlwSJZAACETcQGiw2ymSE1rGUBAEDYRGSw6NFD2uQ5vH36Ci6FAAAQLhEZLOLipLKOFiwOrWHEAgCAcInIYCFJh3rapZC4zQQLAADCJWKDRcKAw2tZ7NkiVVa6XA0AANEhYoNFxxM6qVxt5XEcadMmt8sBACAqRGywyOntOTIzhCmnAACER+QGC/+1LFgkCwCAsIiKYFG9jhELAADCIWKDRZcuUlGcXQo5sJJgAQBAOERssPB4pP1dbMSiZj2XQgAACIeIDRaS7HqIpIQtjFgAABAOER0skk6wYJG0f4+0d6+7xQAAEAUiOlhk9m+nHepkD5hyCgBAyDUpWDz66KPyeDwBx8CBA0NVW7OxfToAAOEV19QvGDx4sGbNmuX7BnFN/hZhk5MjrVZvjdTXBAsAAMKgyakgLi5OXbt2DUUtQZeTI31weMTi4JoCxbtcDwAAka7JPRZr165VZmamevfurUmTJqmwsLDe11ZWVqq0tDTgCKf27aVtbSxYVHzLlFMAAEKtScFi5MiReumll/TBBx/o2WefVUFBgUaPHq2ysrI6Xz9lyhSlpaUdObKysoJSdFNUZdJjAQBAuHgcx3GO94v37t2r7Oxs/eY3v9ENN9xw1POVlZWq9NuyvLS0VFlZWSopKVFqaurx/tgmueW8DXr2oz46FJ+kuIpyKSaiJ8IAABB0paWlSktLa9Tf72Z1Xqanp6t///5at25dnc8nJiYqMTGxOT+i2dqdkKXqj2IUd7BC2rZNysx0tR4AACJZs/75vm/fPq1fv17dunULVj1Bl903XkU6fAmGyyEAAIRUk4LFvffeq/z8fG3cuFFffvmlLrvsMsXGxmrixImhqq/ZcnKkDbLNyAgWAACEVpMuhWzevFkTJ07U7t271alTJ5111lmaN2+eOnXqFKr6mi0nR5qrHEmfSRuYGQIAQCg1KVhMnz49VHWETK9e0quH17KoXFUgdzs+AACIbBE/RaJtW2l3ql0KqVrDpRAAAEIp4oOFJB3KshGL2E1cCgEAIJSiIljE9bNg0Wb3ZqmqyuVqAACIXFERLNoP7KL9aiOP40jHWIIcAAA0T1QEi5zeHt/26cwMAQAgZKIjWOTIFyxYywIAgJCJmmDhXSTL2UCwAAAgVKIiWPTsKW30sH06AAChFhXBIj5eKsuwYFG9lhELAABCJSqChSRVZ9ulkLjNBAsAAEIlaoJF4kAbsUjat1sqLXW5GgAAIlPUBItu/VO0Sxn2gJkhAACERNQEC7ZPBwAg9KIqWLCWBQAAoRWVwaJmHVNOAQAIhagJFt26SZvjvGtZMGIBAEAoRE2wiImRyruw+iYAAKEUNcFCkl0PkZS4tUByHJeLAQAg8kRVsEg+oadq5FFc1QFp+3a3ywEAIOJEVbDo2TdBRcqyB8wMAQAg6KIqWARMOd3AzBAAAIIteoMFIxYAAARdVAWLXr18q28eYpdTAACCLqqCRUaGtDXRRiwqV3EpBACAYIuqYOHxSFXdLVh4NjJiAQBAsEVVsJCkmL52KSRpZ5F08KDL1QAAEFmiLlikD+yqA0pSjFMjFRa6XQ4AABEl6oJFTm+PNqqXPWBmCAAAQRV9wSLHNzOEYAEAQHBFZbBgkSwAAEIjqoNF1RpGLAAACKaoCxbt2km7UuxSyEGCBQAAQRV1wUKSDvW0EYu4Qi6FAAAQTFEZLOL7WbBILN0l7dvncjUAAESOqAwWXQekabc62ANmhgAAEDRRGSzY5RQAgNAgWDDlFACAoInaYOFdJMvZwIgFAADBEpXBomdPaaN3LYvVBAsAAIIlKoNFYqJU1tGCRfU6LoUAABAsURksJKkm24JFwpYCyXFcrgYAgMgQtcEiaUC2auRRXOV+aedOt8sBACAiRG2wyOqbqC3qbg+YGQIAQFBEbbDIyZHWqp89WL7c3WIAAIgQUR0svtCZ9iA/391iAACIEFEdLPKVK0ly8vNp4AQAIAiiNlh07y4tiBulKsXLU1QkbdzodkkAALR6URssYmOljtnJmq9T7URenqv1AAAQCaI2WEjSwIFSns62B/RZAADQbFEdLEaP9vVZMGIBAEDzRXWwyM2VvtQZOqg4adMm+iwAAGimqA4WI0ZISm6nBTrFTnA5BACAZonqYBEfL51xBn0WAAAES1QHC8kuh9BnAQBAcBAscm0FzkOKlQoKpKIit0sCAKDVivpgceqp0qGkFC3UCDvB5RAAAI5b1AeLxETp9NO5HAIAQDBEfbCQ7HIIDZwAADQfwUIWLOboLFUrRlq3Ttqyxe2SAABolQgWskshlQmpWqThdoJRCwAAjgvBQlKbNtJpp9FnAQBAcxEsDqPPAgCA5iNYHObts6iRR1qzRtq61e2SAABodQgWh51xhlQel64lOslOMGoBAECTESwOS06WTjmFyyEAADQHwcLPmDE0cAIA0BwECz+5udLnGm19FqtWSdu3u10SAACtCsHCz1lnSSUxHbRUQ+0El0MAAGgSgoWf1FTp5JPpswAA4HgRLGrJzaXPAgCA40WwqCU3V5qtMfZg5Upp5053CwIAoBVpVrCYOnWqPB6P7r777iCV477Ro6U9ngwt1Yl2YvZsdwsCAKAVOe5gMX/+fD3//PMaOnRoMOtxXfv20tChXA4BAOB4HFew2LdvnyZNmqQ///nPat++fbBrct2YMTRwAgBwPI4rWNx222266KKLNG7cuGO+rrKyUqWlpQFHaxDQZ7FsmbRrl7sFAQDQSjQ5WEyfPl2LFi3SlClTGnztlClTlJaWduTIyso6riLDbcwYaZc6abkG24nPP3e3IAAAWokmBYuioiLdddddeuWVV5SUlNTg6ydPnqySkpIjR1FR0XEXGk6dOkmDBtFnAQBAUzUpWCxcuFA7duzQ8OHDFRcXp7i4OOXn5+vpp59WXFycqqurA16fmJio1NTUgKO1yM2lzwIAgKaKa8qLx44dq2XLlgWcu+666zRw4EA98MADio2NDWpxbsrNle589nCfxdKl0nffSR06uFsUAAAtXJOCRUpKioYMGRJwLjk5WRkZGUedb+1yc6Ud6qKVOkGDnG+tz2LCBLfLAgCgRWPlzXp07Sr17+/XZ8HlEAAAGtSkEYu65EVwY+OYMVL+mlzdoudo4AQAoBEYsTiGgA3JliyR9u51sxwAAFo8gsUx5OZK29RNq9VfchzWswAAoAEEi2PIypJycph2CgBAYxEsGhBwOYQ+CwAAjolg0YCAYLF4sVRS4m5BAAC0YASLBuTmSsXqrrXqK9XUSHPmuF0SAAAtFsGiAb16Wa8FfRYAADSMYNEAj6fW5RCCBQAA9SJYNMKYMX7BYuFCqazM3YIAAGihCBaNkJsrbVaWNqi3VF0tffGF2yUBANAiESwaoV8/2zskj2mnAAAcE8GiEbx9FjRwAgBwbASLRgpo4Jw/X9q3z92CAABogQgWjZSbKxUqWxs9vazP4ssv3S4JAIAWh2DRSCecIHXqJOU5h0ctZs1ytyAAAFoggkUjeTw27fQ9XWQnXn1VOnTI3aIAAGhhCBZNMGaM9E9dopL4DGnLFunDD90uCQCAFoVg0QS5uVKVEvWy82M7MW2auwUBANDCECya4MQTpfbtpecO3WAn3n1X2rbN3aIAAGhBCBZNEBMjjR4trdRgbel5uvVYvPyy22UBANBiECyaaNw4u/2L5//ZnWnTJMdxryAAAFoQgkUTTZwoxcdLT2y6UtVt20lr10qff+52WQAAtAgEiybq2FGaMEEqVzt9lXOVnXzhBXeLAgCghSBYHIcbDvduPlZ0+M6bb0p797pWDwAALQXB4jiMHy/16CF9VDpSJT0GSwcOSK+95nZZAAC4jmBxHGJjpWuvlSSPXkv2a+IEACDKESyO03XX2e3Dq/9DTny8tGiRtHixu0UBAOAygsVx6t1b+t73pF3qqBX9L7OTNHECAKIcwaIZvE2cU3cevhzy979bvwUAAFGKYNEMl18upaVJr+4YqwNdsqWSEukf/3C7LAAAXEOwaIY2baSrr5YcxejdztfbSS6HAACiGMGimbyXQyavvk6OxyPl5dlqnAAARCGCRTMNHy4NHSqtr8rSphPOt5N/+Yu7RQEA4BKCRTN5PL5Ri/87cPjOSy/ZzqcAAEQZgkUQTJokJSRITxdcrIPtO0nbtkn//rfbZQEAEHYEiyDIyJAuvVQ6qATlZ19jJ1mJEwAQhQgWQeK9HPLQ+sN33ntPKi52ryAAAFxAsAiSsWOlnj2l+WUDtbP/mVJNjfTXv7pdFgAAYUWwCBLfxmTSS3GHV+J84QULGAAARAmCRRB5NyZ7dOUPVdMuRVq/XsrPd7coAADCiGARRL162SWR/UrWov4T7SRNnACAKEKwCDJvE+ejmw9fDvnHP6Q9e9wrCACAMCJYBNlll0np6dJ7O05RWc5QqbJSeuUVt8sCACAsCBZBlpRkC2ZJHr2Vdnj4Yto0yXHcLAsAgLAgWITAkY3JVvyHnMRE6ZtvpEWL3C0KAIAwIFiEwMkn27H9YAetGXy5naSJEwAQBQgWIXL99Xb76z2Hhy9efVUqL3evIAAAwoBgESKTJkmJidILBd9TZfccqbRUeustt8sCACCkCBYh0r69dPnlkqMYfdj98KjFH/9IEycAIKIRLELIeznkvm+vl9OmjfT119Lbb7tbFAAAIUSwCKFzzpGys6U1Zd20/Lx77eQDD9jaFgAARCCCRQjFxPj2D3lg9/1S1662f8gzz7hbGAAAIUKwCLHrrpM8Hun9z9tp512/tJP//d/S7t3uFgYAQAgQLEKsZ09p/Hi7/4eya6WhQ6W9ey1cAAAQYQgWYeBt4nzuz7Ha//iv7cEzz0hr1rhXFAAAIUCwCIPLL5f69ZN27pSmLBgvXXihdOiQNXICABBBCBZhEB8vPfGE3f/f/5W23fOUFBsrvfOOlJ/vam0AAAQTwSJMLr1UOvNM6cABafLfBkk33mhP3HOPVFPjam0AAAQLwSJMPB4brZCkv/5VWvGjx6SUFGnhQttHBACACECwCKORI6Urr7RVvf9rSmfpoYfsicmTpf373S0OAIAgIFiE2a9+ZT0XH38sfTz4bpuPunmz9Nvful0aAADNRrAIs969pTvusPs/fShJNf8zxR5MnSpt2+ZeYQAABAHBwgU/+5mUni4tXy69eOAq6bTTpH37pEcecbs0AACahWDhgg4dpIcftvsPPxKjA//zG3swbZqlDQAAWimChUtuu03KyZG2bpWe+vJM6YorbNrpvfe6XRoAAMeNYOGSxERpyuH2iieflHbe+4R1dX74oR0AALRCBAsX/ehHNgW1vFz6+Yt9fF2d99xjS34DANDKECxc5PFIvz68J9m0adK3V/zcGjBWrJD+8hd3iwMA4DgQLFx21lnSZZdZe8V9v2rvmxny8MNSWZm7xQEA0EQEixZg6lQpLk567z3pswE321aoO3b4di4DAKCVIFi0AP37SzffbPfvmZygmqlP2oP//V+pqMi9wgAAaKImBYtnn31WQ4cOVWpqqlJTUzVq1Ci9//77oaotqvziF1JqqrR4sfTKvgnSmDFSRYVvPxEAAFqBJgWLHj16aOrUqVq4cKEWLFigc845RxMmTNCKFStCVV/U6NTJ9iKTpJ/93KOKXx1eNOvvf7drJAAAtAIex3Gc5nyDDh066KmnntINN9zQ4GtLS0uVlpamkpISpaamNufHRqQDB6QBA+zqx5Qp0oNb75Keflpq396GMrKz3S4RABCFmvL3+7h7LKqrqzV9+nSVl5dr1KhRdb6msrJSpaWlAQfq16aN7X4q2e3O+5+yfUT27LFFL6qq3C0QAIAGNDlYLFu2TO3atVNiYqJuvvlmzZgxQ4MGDarztVOmTFFaWtqRIysrq9kFR7qrr5aGD7eZpo9PTZDeeMNGLL7+WrrvPrfLAwDgmJp8KaSqqkqFhYUqKSnRW2+9pWnTpik/P7/OcFFZWanKysojj0tLS5WVlcWlkAZ8+qk0dqxNQV2xQuq/5l3p4ovtyTfflH7wA3cLBABElaZcCml2j8W4cePUp08fPf/880EtLNp9//vWs3nppdKMGZIefNDWtUhJkRYutLUuAAAIg7D0WHjV1NQEjEogOJ58UoqJkd55R3r9dUm//KU0erRdI/nhD63TEwCAFqZJwWLy5MmaPXu2Nm7cqGXLlmny5MnKy8vTpEmTQlVf1Bo0yDf99MYbpbUFcdL06VLnztI330h33ulugQAA1KFJwWLHjh368Y9/rAEDBmjs2LGaP3++PvzwQ40fPz5U9UW1Rx+1dbLKymxSSEWHTOnVV233smnTpJdfdrtEAAACNLvHoinosWi64mLppJOknTulW26R/vhHSY8/bpuVtW1rs0UGD3a7TABABAtrjwVCKzPTFt/0eKRnnz3cb/Hzn0vnnivt328zRPbtc7tMAAAkESxahXPP9W0ZcuON0tr1MZY2uneXVq2SfvITKXwDTwAA1Itg0Uoc1W+R0smGL2Jjre/iT39yu0QAAAgWrUVcnPTaa7ZZ2ZIl0k9/KunMM6WpU+0Fd94pLVrkZokAABAsWpM6+y3uuUeaMMH2EfnhD6W9e90uEwAQxQgWrcxR/RbrPNKLL0q9ekkbNkjXXUe/BQDANQSLVuiofos27W0PkYQEW6rzd79zuUIAQLQiWLRCdfZbnHKK9Nvf2gvuv1/6+GM3SwQARCmCRStVZ7/FLbdIEydKhw5Jl1xi26QCABBGBItWrN5+i4sukioqbIvUvDxXawQARBeCRSt3VL+Fkyi99ZZ0/vm2A+pFF0mzZ7tdJgAgShAsWrk6+y2SkqQZM3zLfl94oTRnjtulAgCiAMEiAtTZb5GUZDNExo2TysulCy6QvvzS7VIBABGOYBEhavdbrFghqU0baeZM6ZxzbKOy88+X5s1ztU4AQGQjWESQRx+VcnOt32LsWGn1atnW6v/8p3T22fbEeefZVusAAIQAwSKCxMVJb78tDRsmbd9uAxXr1klKTpbefVcaPVoqLbXhjQUL3C4XABCBCBYRpkMHWxtr8GCpuNjCxcaNsnDx73/bxmUlJdL48WxaBgAIOoJFBOrUSfrkE2nAAKmoyMJFUZGkdu2k99+XRo2yzcrGjbOpJAAABAnBIkJ16WILb/btKxUUWLgoLpaUkiJ98IE0cqS0Z4+Fi6VL3S4XABAhCBYRLDPTwkVOjvVajB1rvRdKTZU+/FA69VRp9257Yvlyt8sFAEQAgkWEy8qycJGVJa1aZRli505JaWnSRx9JI0ZIu3bZkMayZW6XCwBo5QgWUaBXLwsXmZm2vsX48dJ330lKT7dwcfLJljbOOMOmlQAAcJwIFlGib18LF126SN98YzNO9+6VTSOZNUv63vdsEa0rrpAefliqqXG7ZABAK0SwiCIDBthskY4dpYULbZXv0lJZuPjoI+nuu+2Fv/ylNGGCTUsFAKAJCBZRZvBgG6Do0MFW977oIhuoUFyc9NvfSi+/LCUm2oJaI0daYwYAAI1EsIhCw4bZIlrp6bbp6cUX2yaokqT//E872aOHrQl+2mm2JDgAAI1AsIhSw4fbjNOUFCkvT7r0Uqmi4vCTp5xiS36PHm37i0yYID3+OH0XAIAGESyi2Gmn2UKcyck2gjF+vLRjx+Enu3Sxhozbb7fHjzxijZ2lpa7VCwBo+QgWUe7MM6X33rM1s+bMscGKI1uIxMdLf/iD9MILUkKC9M470umnS2vWuFkyAKAFI1hAubm2k3r//ranyFlnSdOn+73g+uul2bNtIYxvv7Whjn//27V6AQAtF8ECkmwq6ldf2RTUAwekiROlyZOl6urDLxg50uaonnGGTUP9/velX/1KchxX6wYAtCwECxyRni7961/SAw/Y46lTay1n0bWr9Nln0k9+YoHiZz+zrs9t21yqGADQ0hAsECA21gLFK69ISUnWfzFypF9bRUKC9Nxz0vPPWw/GP/8pDRpk618wegEAUY9ggTpdffXRy1l88IHfC266SZo/3+at7tkjXXONdOGFUmGhazUDANxHsEC9Royw5SzOPNMuh1x0kfTUU34DE8OGWWPGlCm2WucHH0hDhthoBmteAEBUIljgmLp0sc3LbrzRssL990v/8R/W4CnJlgJ/8EFpyRJp1ChbUOvmm21/9vXr3SwdAOACggUalJBggxDPPGM54tVXbVHOzZv9XjRwoPT559Lvfie1bWvLeZ54ou0/cmRqCQAg0hEs0Cgej3TrrbZCZ0aGzTw95RTLEkfExkp33SUtW2bbsB84IP30p7YwxrffulY7ACB8CBZokrPPtr6LoUOl7dvt8X33+V0akaTevW058Oeft81I5s2TTjrJ1r04eNCdwgEAYUGwQJP16iV9+aX04x9b38Wvf2254csv/V7k8djMkRUrbLZIVZWtezFypPVjAAAiEsECxyU5WfrrX21BrcxMW+firLPsyseRLdglKStLevddW+eifXtp8WKbbnLjjVJxsWv1AwBCg2CBZvn+96Xly6Vrr7VpqL/9rc1CDei98Hik//xPaeVK6Yc/tGGOadOkvn2ln/+cHVMBIIIQLNBs7dtLL75o+5J17y6tW2cbm911l1Re7vfCrl2lN96wlbfOOMMaM/7nf6Q+fWwX1aoq134HAEBwECwQNBdcYC0VN9xgoxdPP21Nnvn5tV545pkWLmbMsN3Pdu2S7rzTlgZ/4w2WBgeAVoxggaBKS7OrHB98YO0VGzbYzJE77pD27fN7ocdjG5gtX257j3TpYgtqXXmlNXjm5bnzCwAAmoVggZA47zzLDDfdZI//7/9s9OKzz2q9MC7Odktdt0569FHrCp0/39bB8DZwAABaDYIFQiY11Zay+PhjKTtbKiiQzjnHVvzetavWi9u1kx55xEYtbrnFFtt67z3rBL3hhlrLfAIAWiqCBUJu3DhbjPPWW+3x889bv+bUqbUW1pLsksgf/2gzSK64wmaQ/OUvUr9+0u2327UVAECLRbBAWKSk2F4jeXnSySfbDNPJk6138+WX69gMtX9/6a23bNWts86SKirsG/TrJ02cKC1a5MavAQBoAMECYZWba0uC/+1vUs+eUlGRdM01tmbWrFl1fMGoUdLs2bZE+HnnWQKZPt2+YPx4u87CLBIAaDEIFgi7mBjben31aumJJ2wmyZIllhMuvLCOfk2Px5ozPvjAVu68+mrrwZg1Szr3XAsZ06dLhw658esAAPwQLOCapCTp/vttQshdd9kEkffft37N//f/6lnx+6STpFdesS+6807bon3xYrs80r+/XS4JWFMcABBOBAu4rmNH6Xe/s53Vf/ADu9rxwgvWTvGLX0hlZXV8Ua9e0u9/LxUWSo89Zt+koMAaPHv2tHNHTT0BAIQawQItRt++0ptvWr/mGWfYwMN//7cFjOeeq2fF74wMSx+bNtloRU6OtHu3rYmRnW0LaSxYQB8GAIQJwQItzqhRtuL3P/5hYWP7dlvaok8f2+QsYAVPr7ZtbT7rmjXWbzF8uCWTP/9ZOvVU68N47jk2PAOAEPM4Tvj+KVdaWqq0tDSVlJQoNTU1XD8WrdjBg7buxa9+JW3daufat7crHnfcIXXqVM8XOo7NJvnzn23aamWlnW/bVrrqKhvJOO00awwFABxTU/5+EyzQKlRW2hTVp56yQQlJatPGFuW85x5ruajX7t32xX/6kzVyeA0dagFj0iQpPT2E1QNA60awQMSqrpbeecdW7VywwM7FxtogxP33W1aol+NIX3xhoxhvvGGLbkmWUH70IwsZo0YxigEAtRAsEPEcxzY0mzrV1sjyuuAC6cEHpdGjG8gHe/ZIf/+7jWL4L5wxeLB0/fW2y2r37iGrHwBaE4IFosqiRbbQ1ltv+ZYGP/106YEHpEsusQW56uU40rx5FjBef923eYnHI40ZY+tj/OAHNvsEAKIUwQJRad066de/ll56ydermZMj3XijdN11UteuDXyDvXul116zBbi++MJ3Pi7OlgWdOFG69FLb+AQAogjBAlFt2zZbO+vZZ6WSEjsXF2eZ4KabpLFjGxjFkGzhrddft6CxeLHvfFKS9P3vW8i48EJ7DAARjmAByJaxeOMNm646b57vfJ8+Nopx7bW2S3uDVq+2gPHaa74pKZKNXFx2mYWMsWOl+Phg/woA0CIQLIBali61Noq//c23RlZ8vI1i/OQn0ve+14hRDMex3dJee80W4Soq8j2XkSFddJE1dZx3ntSuXYh+EwAIP4IFUI/ycrvC8ac/SV995Tvft69vFKNz50Z8o5oaae5cCxlvvCHt3Ol7LiHBRjAuuUS6+GJmlwBo9QgWQCMsWWIB4+9/9210Fh9veeDqq5vQQnHokG1wMnOmHevXBz5/yin2TS+5xBbaYJ0MAK0MwQJogvJyu7Lx/PPS/Pm+86mp0uWXWwvFOedYA2iDHMdW9/znP+2YNy9wA7TsbF/IyM2lLwNAq0CwAI7TN9/YbNPaLRSdO9vinBMnNnFxzu3bpXfftZDx8ce+dTIkSy5jx0rnnmtH795B/V0AIFgIFkAz1dTYUhbeFordu33PZWdbwJg4UTrxxCaEjP37pVmzLGT861/Sjh2Bz/fpYwFj/HgbIklLC9rvAwDNQbAAgujgQcsDr75q+5T4b9s+eLAFjKuuslzQaDU1ttnJxx9LH31kPRqHDvmej42VRo70jWacemojr8UAQPCFLFhMmTJFb7/9tlatWqU2bdrojDPO0BNPPKEBAwYEvTCgJdq/X3rvPQsZ//63VFXle27IEGnCBDtGjGjE9FV/ZWVSXp4vaKxeHfh8WpqNYpx7rt3260cTKICwCVmwOP/883XVVVfp1FNP1aFDh/TQQw9p+fLlWrlypZKTk4NaGNDS7d0rvf22XS757DPbedUrM9Nmmk6YYDkgMbGJ33zTJl/ImDXLNk3z17Wr7WWSm2vHoEEEDQAhE7ZLITt37lTnzp2Vn5+vMWPGBLUwoDX57jsbwZg5U/rgg8DLJe3a2ZpZEybYGlodOjTxm1dX205rH31kYWPePN9mKF4dO1rQ8IaNoUObOGQCAPULW7BYt26d+vXrp2XLlmnIkCFHPV9ZWalKv/8BlpaWKisri2CBiFZZaSMYM2dan2Zxse+52Fjb0v2SSyxoHNdEkIoKW91r9mwpP9/6M/xnm0hSerr9oNxcCxsnn0yPBoDjFpZgUVNTo0suuUR79+7VnDlz6nzNo48+qscee+yo8wQLRIuaGmnhQgsYM2dKy5YFPt+/v68/8+yzj3Pj1KoqawTNz7fjiy8Ch0wkKTnZGkBHjbLj9NOlTp2O99cCEGXCEixuueUWvf/++5ozZ4569OhR52sYsQACbdhgM01nzrQBB/++jLg46YwzfEFj+HAb4WiyQ4dsR1Zv0Pj8c982r/769vUFjVGjrPuUUQ0AdQh5sLj99ts1c+ZMzZ49Wzk5OSEpDIh0JSU2EeSjj+xYty7w+Q4dpHHjfEEjK+s4f1B1ta0GOneu71i16ujXJSdLp51moxmMagDwE7Jg4TiO7rjjDs2YMUN5eXnq169fyAoDos2GDb6Q8cknvl1YvQYO9M02HT36OJpA/X33nfVpzJ1rzaBffXX0D5RsNbBTTvEdI0ZI7ds34wcDaI1CFixuvfVWvfrqq5o5c2bA2hVpaWlq06ZNUAsDotmhQ9LXX/uCxldfWb+GvxNP9M02HTOmkbuy1qexoxqSrQTmHzaGD7flyQFErJAFC0898+RffPFFXXvttUEtDIDP3r3Sp5/abNO8vLr/5p9wgi9o5OZK3bo184eWlNg01wUL7Fi48OidW70GDPCNaJx8sjRsGCMbQARhSW8gwm3fbj2Z3v7M2rNNJOvN9IaM0aPtqkaz19D67rvAsLFggS3mVZeePaWTTrJj2DC7zclhIS+gFSJYAFFm925f0Jg9W1qy5OhLJ127Bk4CGTFCasQVzIbt3GmjGQsXWtD45hupoKDu16amWsjwBo2TTrINV5KSglAIgFAhWABRrqREmjPHN6KxaFHgHmeSzSw96aTAsBGUUQ3Jrt0sXWoJ55tv7Hb58sDNVbxiY21BjyFDAo8+fY5zvi2AYCNYAAhw4IANKPj3Zm7bdvTrvKMa3hmnJ59sS5IHxcGD1hziDRrew39Pen9JSdY4UjtwZGVxOQUIM4IFgGNyHKmwMDBoLF589KiGx2N/20eM8E0COekkqW3bIBZSXGyjGf7HihVHL1PulZJiAWPwYCtu4EC7zc5mfxQgRAgWAJrswAG7ZOINGl9/LW3efPTrYmJsM1X/pS2GDQtSv4ZXdbW0cePRgWPVqqPTj1ebNjY75YQTAo9+/aSEhCAWB0QfggWAoNi2LbAvc/78ui+hxMbaIIL/bNOhQ6W0tCAXVFUlrV1r02BWrrS1N7791s7V1b/hLa5PH9/oxoAB1tPRv7/tCstlFaBBBAsAIVNc7FvWwhs2du6s+7U5Ob4JIN7JIL16heBv+aFDNhPFGzT8j7Ky+r8uPd0XMvyPfv2C2FwCtH4ECwBh4zh2ycQbNr75xo6iorpfn5Zmoxn+YWPw4CBfSvEvrrjYFzJWrbLRjTVrrMnkWP/7y8z0BY2+fe3o08eO5OQQFAu0XAQLAK7bvTtwxuk331hP5sGDR782Jsb+Xg8e7Jv8MXiw/U0PWXvEgQO2kuiaNUcf9Q3BeHXr5gsa/qGjb18bBQEiDMECQItUVRU449QbOHbtqvv1cXHWEuENGt7Q0bt3iJe42LPHN7KxerUFkHXr7Pa77479tRkZvpGNnBwr1nvbowdb06NVIlgAaDUcR9qxI3Cmqfd+fe0RSUnWh+mdaeq9379/GBbx/O47Cxj+YWPdOjvq6mz1FxdnS53XDhze24wMmknRIhEsALR63t6N2oFj5cr6l7jweOxvtH/Y8N7PyAhD0fv2SRs2WMgoKLD7/rf1zVzxatfO1uPo1avu286dCR5wBcECQMSqrvZNAFm1KvB27976v65Tp8CZpv362W2fPiFqHK2tpkbauvXosOG93bKl4e/Rpo2NeNQOHD172pGZyaUWhATBAkDU8V5S8Q8b3vuFhfV/ncdjq4T7hw3vba9eUnx8mH6BigordONG2zG29u2WLceexSJZF2z37hYysrJ8gcN7ZGXZdvaMeqCJCBYA4Ke83Pow/Webrl1rfZklJfV/XVycXVrx9mJ6J354+zLDMtLhVVVl14bqCh5FRXbUNeWmtuRkCxk9eljQ6NEj8MjKsjnBhA/4IVgAQCM4js1I8YYNb+Dw3tbXy+HVvXtg6PA/OnQIz+9wRE2NtH27BYzCwrqPhqbReiUnBwYN7/3u3X1Hx47szRJFCBYA0Ew1Nba21tq1gRNAvEdp6bG/Pi3NRjXqOnr1CuJGbk1x4ICNehQW2m1Rkd16j6KihqfTesXH23oe/mGj9pGZ6dIvimAjWABACDmOLQDmHzT8g0dDs04lm+BRV+Dw9mKGfNpsffbvt36O2sGjqMjOb9lizSyNlZ5uAaNbt7pvvffDel0JTUWwAAAXlZdb60NBQeDhPXesvg6vLl0sZNR3BH2Dt6aoqrL05A0a/kdxse/+/v2N/57p6YGho1s3qWtXO7z3u3Wj/8MlBAsAaMH27Kk7cHj7McvLG/4eqam+kJGV5Tu8kz+6d3d5t3jHsQRVXGzTbI9121Azi7/ExMCw4X+/SxffbZcujIIEEcECAFopx7E2h02bfEdhYeDj+pZA9+fx2N9W/9DhHzx69LC/wa4ve+E41rBSO3Bs2+a79d4/1kIldUlJCQwa3qP2uc6d2ViuAQQLAIhg5eW+sFFY6JsI4p11WlQkVVY2/H1iYuxvbPfugRM/at+2mH/4V1QEBg3/+1u32qwY79GYN8BfcrIFjM6dfWGjvvsZGVE3I4ZgAQBRzDuNtq7A4T2Ki6VDhxr3/Tp0CJzo4T38H3fu3AJGP7y8l2H8g8a2bXXfP54QEhNj4aJzZ1vStfZt7XPt27f6IEKwAAAcU02NTe7YvNk3CcR763+/sf2XMTH2D3r/4OE9vG0Q3bq1sAAiWQgpK7M3w3ts317/4927m/4zYmMtiHhDR8eODd8mJgb/d20GggUAoNm8//D3ho3i4rqPbdtsD5fG8HgsXPhP9PA//HsxW+QSGAcP2nDQzp127Nhx9H3/26b2hXilpFjA8D8yMup/nJER0vXnCRYAgLCprra/o96Zpt7AsWVLYB/m9u02UtJYKSm+Xsvakz78z3Xp0uL+ge/jDSI7dvgCSe3b2ucam9JqS031BY0vvgjq0BDBAgDQ4ngDiH/Y8D/8z1VUNO17p6f7eisbuk1JacFLYdTU2DCRN3Ds3m3Hrl2+w/+x93n/P+Xt2tnlnSBqyt/vlnSlCwAQwWJjfSMNx+Jte/D2WPrf1nXu4EG74rB3r20s15CkpMCgUbvXsva5sK6CGhNjzZ7t29sWu41RXW2/vDdo7NsX0hIbwogFAKDVchxbcMw7wcO/17Kuc41ZfKy2lJSjQ4e3z7Ku3svk5BY8InKcGLEAAEQFj8emw3boIJ1wQsOvLy8PDB119Vz63z940EZPysqkDRsaV1NS0tFho1Ono3st/e+32B6R40CwAABEjeRk36ZvDfHOiqkvfNTVg1lRYYd3Jk1jtWt3dOjIyDj20VJHRggWAADUweOxptD0dKlfv8Z9TXl5/aHD23NZu/eyutraIvbts/1iGishof7Q8ctfWk+LG+ixAADAJTU1tlVKQ5M+ah9VVfV/z8RE29ctmKMZ9FgAANAKxMT4RkX69m3c1ziOjYwcK3S4eYmEYAEAQCvi8VhPRrt2Una229UcrXXvigIAAFoUggUAAAgaggUAAAgaggUAAAgaggUAAAgaggUAAAgaggUAAAgaggUAAAgaggUAAAgaggUAAAgaggUAAAgaggUAAAgaggUAAAiasO5u6jiOJNvXHQAAtA7ev9vev+PHEtZgUVZWJknKysoK548FAABBUFZWprS0tGO+xuM0Jn4ESU1NjYqLi5WSkiKPx1Pv60pLS5WVlaWioiKlpqaGq7wWh/fB8D748F4Y3gfD++DDe2FC9T44jqOysjJlZmYqJubYXRRhHbGIiYlRjx49Gv361NTUqP6AePE+GN4HH94Lw/tgeB98eC9MKN6HhkYqvGjeBAAAQUOwAAAAQdMig0ViYqIeeeQRJSYmul2Kq3gfDO+DD++F4X0wvA8+vBemJbwPYW3eBAAAka1FjlgAAIDWiWABAACChmABAACChmABAACCpsUFi2eeeUa9evVSUlKSRo4cqa+//trtksLu0UcflcfjCTgGDhzodlkhN3v2bF188cXKzMyUx+PRO++8E/C84zj6xS9+oW7duqlNmzYaN26c1q5d606xIdTQ+3Dttdce9fk4//zz3Sk2hKZMmaJTTz1VKSkp6ty5sy699FKtXr064DUVFRW67bbblJGRoXbt2umKK67Q9u3bXao4dBrzXpx99tlHfS5uvvlmlyoOjWeffVZDhw49svjTqFGj9P777x95Plo+Dw29D25/FlpUsHj99df105/+VI888ogWLVqkYcOG6bzzztOOHTvcLi3sBg8erK1btx455syZ43ZJIVdeXq5hw4bpmWeeqfP5J598Uk8//bSee+45ffXVV0pOTtZ5552nioqKMFcaWg29D5J0/vnnB3w+XnvttTBWGB75+fm67bbbNG/ePH388cc6ePCgzj33XJWXlx95zX/913/pX//6l958803l5+eruLhYl19+uYtVh0Zj3gtJuvHGGwM+F08++aRLFYdGjx49NHXqVC1cuFALFizQOeecowkTJmjFihWSoufz0ND7ILn8WXBakNNOO8257bbbjjyurq52MjMznSlTprhYVfg98sgjzrBhw9wuw1WSnBkzZhx5XFNT43Tt2tV56qmnjpzbu3evk5iY6Lz22msuVBgetd8Hx3Gca665xpkwYYIr9bhpx44djiQnPz/fcRz77x8fH++8+eabR17z7bffOpKcuXPnulVmWNR+LxzHcXJzc5277rrLvaJc0r59e2fatGlR/XlwHN/74DjufxZazIhFVVWVFi5cqHHjxh05FxMTo3Hjxmnu3LkuVuaOtWvXKjMzU71799akSZNUWFjodkmuKigo0LZt2wI+H2lpaRo5cmRUfj7y8vLUuXNnDRgwQLfccot2797tdkkhV1JSIknq0KGDJGnhwoU6ePBgwGdi4MCB6tmzZ8R/Jmq/F16vvPKKOnbsqCFDhmjy5Mnav3+/G+WFRXV1taZPn67y8nKNGjUqaj8Ptd8HLzc/C2HdhOxYdu3aperqanXp0iXgfJcuXbRq1SqXqnLHyJEj9dJLL2nAgAHaunWrHnvsMY0ePVrLly9XSkqK2+W5Ytu2bZJU5+fD+1y0OP/883X55ZcrJydH69ev10MPPaQLLrhAc+fOVWxsrNvlhURNTY3uvvtunXnmmRoyZIgk+0wkJCQoPT094LWR/pmo672QpKuvvlrZ2dnKzMzU0qVL9cADD2j16tV6++23Xaw2+JYtW6ZRo0apoqJC7dq104wZMzRo0CAtWbIkqj4P9b0PkvufhRYTLOBzwQUXHLk/dOhQjRw5UtnZ2XrjjTd0ww03uFgZWoKrrrrqyP0TTzxRQ4cOVZ8+fZSXl6exY8e6WFno3HbbbVq+fHlU9Bo1pL734qabbjpy/8QTT1S3bt00duxYrV+/Xn369Al3mSEzYMAALVmyRCUlJXrrrbd0zTXXKD8/3+2ywq6+92HQoEGufxZazKWQjh07KjY29qgO3u3bt6tr164uVdUypKenq3///lq3bp3bpbjG+xng83G03r17q2PHjhH7+bj99tv17rvv6rPPPlOPHj2OnO/atauqqqq0d+/egNdH8meivveiLiNHjpSkiPtcJCQkqG/fvhoxYoSmTJmiYcOG6fe//33UfR7qex/qEu7PQosJFgkJCRoxYoQ++eSTI+dqamr0ySefBFw3ikb79u3T+vXr1a1bN7dLcU1OTo66du0a8PkoLS3VV199FfWfj82bN2v37t0R9/lwHEe33367ZsyYoU8//VQ5OTkBz48YMULx8fEBn4nVq1ersLAw4j4TDb0XdVmyZIkkRdznoraamhpVVlZG1eehLt73oS5h/yy41jZah+nTpzuJiYnOSy+95KxcudK56aabnPT0dGfbtm1ulxZW99xzj5OXl+cUFBQ4X3zxhTNu3DinY8eOzo4dO9wuLaTKysqcxYsXO4sXL3YkOb/5zW+cxYsXO5s2bXIcx3GmTp3qpKenOzNnznSWLl3qTJgwwcnJyXEOHDjgcuXBdaz3oayszLn33nuduXPnOgUFBc6sWbOc4cOHO/369XMqKircLj2obrnlFictLc3Jy8tztm7deuTYv3//kdfcfPPNTs+ePZ1PP/3UWbBggTNq1Chn1KhRLlYdGg29F+vWrXMef/xxZ8GCBU5BQYEzc+ZMp3fv3s6YMWNcrjy4HnzwQSc/P98pKChwli5d6jz44IOOx+NxPvroI8dxoufzcKz3oSV8FlpUsHAcx/nDH/7g9OzZ00lISHBOO+00Z968eW6XFHZXXnml061bNychIcHp3r27c+WVVzrr1q1zu6yQ++yzzxxJRx3XXHON4zg25fThhx92unTp4iQmJjpjx451Vq9e7W7RIXCs92H//v3Oueee63Tq1MmJj493srOznRtvvDEiw3dd74Ek58UXXzzymgMHDji33nqr0759e6dt27bOZZdd5mzdutW9okOkofeisLDQGTNmjNOhQwcnMTHR6du3r3Pfffc5JSUl7hYeZNdff72TnZ3tJCQkOJ06dXLGjh17JFQ4TvR8Ho71PrSEzwLbpgMAgKBpMT0WAACg9SNYAACAoCFYAACAoCFYAACAoCFYAACAoCFYAACAoCFYAACAoCFYAACAoCFYAACAoCFYAACAoCFYAACAoCFYAACAoPn/Fr1Gzsayu6EAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure()\n",
    "plt.plot(np.arange(max_training)+1, train_losses, color='blue')\n",
    "plt.plot(np.arange(max_training)+1, test_losses, color='red')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "id": "e7d5f0ae",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[5.634050440666685,\n",
       " 4.387224578840572,\n",
       " 3.732324101427636,\n",
       " 3.3156136558640017,\n",
       " 3.0256615601008763,\n",
       " 2.811964422060676,\n",
       " 2.647719286830129,\n",
       " 2.517327013357274,\n",
       " 2.41108289896955,\n",
       " 2.322637761473518,\n",
       " 2.247669373950127,\n",
       " 2.1831387175154915,\n",
       " 2.126850666726065,\n",
       " 2.077182910435453,\n",
       " 2.0329124948899735,\n",
       " 1.9931013699993763,\n",
       " 1.957018850911237,\n",
       " 1.924087871537498,\n",
       " 1.8938469727466771,\n",
       " 1.8659229348805848,\n",
       " 1.8400107568568118,\n",
       " 1.8158587973525266,\n",
       " 1.7932576019660427,\n",
       " 1.7720314009705054,\n",
       " 1.7520315678371967,\n",
       " 1.7331315349857963,\n",
       " 1.7152228047331706,\n",
       " 1.698211791932332,\n",
       " 1.6820173043109274,\n",
       " 1.6665685161844546,\n",
       " 1.65180332711176,\n",
       " 1.6376670232765635,\n",
       " 1.6241111787190878,\n",
       " 1.6110927479420472,\n",
       " 1.5985733122318575]"
      ]
     },
     "execution_count": 156,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce360760",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "younger",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
